from cProfile import label
import math
from re import RegexFlag
from typing import List, Optional, Tuple, Union

import numpy as np
from numpy.linalg import norm
import torch
from torch import nn
from torch.nn import CrossEntropyLoss, MSELoss
from transformers.modeling_outputs import (
    BaseModelOutputWithPast,
    CausalLMOutputWithPast,
)
from transformers.models.llama import LlamaPreTrainedModel
from transformers.models.llama.configuration_llama import LlamaConfig
from transformers.models.llama.modeling_llama import (
    LlamaAttention,
    LlamaDecoderLayer,
    LlamaMLP,
    LlamaRMSNorm,
    apply_rotary_pos_emb,
)
from transformers.utils import logging

from .generation_utils import GistGenerationMixin
from .gist_caching import GistActivations

logger = logging.get_logger(__name__)


# Size of the pretrained vocabulary. AugmentedTokenEncoder treats ids starting
# at this value as newly added tokens and subtracts it to index their table.
PRETRAINED_VOCAB_SIZE = 32000
# Target L2 norm for the added-token embeddings; used by the norm
# regularization term in GistLlamaForCausalLM.forward.
AVG_NORM = 1.28
# Small two-layer configuration, presumably for debugging (per its name).
# NOTE(review): it does not set `num_new_tokens`, `regression_out_dim` or
# `output_dir`, which the model classes in this file read from their config.
DEBUG_LLAMA_CONFIG = LlamaConfig(
    vocab_size=PRETRAINED_VOCAB_SIZE,
    hidden_size=4096,
    intermediate_size=1024,
    num_hidden_layers=2,
    num_attention_heads=32,
)
# CUDA when available, otherwise CPU. Not referenced by the code visible in
# this file; possibly imported by other modules -- TODO confirm.
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')


def _make_causal_mask(
    input_ids_shape: torch.Size, dtype: torch.dtype, past_key_values_length: int = 0
):
    """
    Make causal mask used for bi-directional self-attention.
    """
    bsz, tgt_len = input_ids_shape
    mask = torch.full((tgt_len, tgt_len), torch.tensor(torch.finfo(dtype).min))
    mask_cond = torch.arange(mask.size(-1))
    mask.masked_fill_(mask_cond < (mask_cond + 1).view(mask.size(-1), 1), 0)
    mask = mask.to(dtype)

    if past_key_values_length > 0:
        mask = torch.cat(
            [torch.zeros(tgt_len, past_key_values_length, dtype=dtype), mask], dim=-1
        )
    return mask[None, None, :, :].expand(
        bsz, 1, tgt_len, tgt_len + past_key_values_length
    )


def _expand_mask(mask: torch.Tensor, dtype: torch.dtype, tgt_len: Optional[int] = None):
    """
    Expands attention_mask from `[bsz, seq_len]` to `[bsz, 1, tgt_seq_len,
    src_seq_len]`.
    """
    bsz, src_len = mask.size()
    tgt_len = tgt_len if tgt_len is not None else src_len

    expanded_mask = mask[:, None, None, :].expand(bsz, 1, tgt_len, src_len).to(dtype)

    inverted_mask = 1.0 - expanded_mask

    return inverted_mask.masked_fill(
        inverted_mask.to(torch.bool), torch.finfo(dtype).min
    )


class GistLlamaRotaryEmbedding(torch.nn.Module):
    """Cached cos/sin tables for rotary position embeddings.

    The tables cover positions `0 .. max_seq_len_cached - 1` and are stored as
    plain attributes (`cos_cached` / `sin_cached`), not registered buffers, so
    they are not part of the module's `state_dict`. They grow on demand.
    """

    def __init__(self, dim, max_position_embeddings=2048, base=10000, device=None):
        """
        Args:
            dim: size of the last dimension of the returned tables.
            max_position_embeddings: number of positions to precompute.
            base: base of the geometric frequency progression.
            device: device on which `inv_freq` (and the tables) are created.
        """
        super().__init__()
        inv_freq = 1.0 / (base ** (torch.arange(0, dim, 2).float().to(device) / dim))
        self.register_buffer("inv_freq", inv_freq)

        # Build here to make `torch.jit.trace` work.
        self._build_cache(max_position_embeddings)

    def _build_cache(self, num_positions):
        """(Re)build the cos/sin tables for positions `0 .. num_positions - 1`."""
        self.max_seq_len_cached = num_positions
        t = torch.arange(
            num_positions,
            device=self.inv_freq.device,
            dtype=self.inv_freq.dtype,
        )
        freqs = torch.einsum("i,j->ij", t, self.inv_freq)
        # Different from paper, but it uses a different permutation in order to
        # obtain the same calculation
        emb = torch.cat((freqs, freqs), dim=-1)
        self.cos_cached = emb.cos()[None, None, :, :]
        self.sin_cached = emb.sin()[None, None, :, :]

    def forward(self, x, seq_len=None, gist_offset: Optional[torch.LongTensor] = None):
        """Return `(cos, sin)` for `seq_len` positions starting at the offset.

        Args:
            x: reference tensor; only its dtype, device and (when `seq_len` is
                omitted) its second-to-last dimension are used.
                Per the original comment it is
                `[bs, num_attention_heads, seq_len, head_size]`.
            seq_len: number of positions to return; defaults to `x.shape[-2]`.
            gist_offset: optional tensor whose first element is used as the
                starting position.

        Returns:
            Two tensors of shape `[1, 1, seq_len, dim]` in `x`'s dtype/device.
        """
        if seq_len is None:
            seq_len = x.shape[-2]

        offset = 0
        if gist_offset is not None:
            # XXX: THIS WILL BREAK FOR BATCH SIZE > 1.
            offset = gist_offset[0].item()

        # Grow the tables when the requested window (offset included) runs
        # past what is cached; otherwise the slices below would silently come
        # back shorter than `seq_len`.
        last_position = offset + seq_len
        if last_position > self.max_seq_len_cached:
            self._build_cache(last_position)

        return (
            self.cos_cached[:, :, offset : offset + seq_len, ...].to(
                dtype=x.dtype, device=x.device
            ),
            self.sin_cached[:, :, offset : offset + seq_len, ...].to(
                dtype=x.dtype, device=x.device
            ),
        )


class GistLlamaAttention(LlamaAttention):
    """Multi-headed attention from 'Attention Is All You Need' paper.

    Variant that is constructed from `hidden_size` / `num_heads` directly and
    uses `GistLlamaRotaryEmbedding` for the rotary position tables.
    """

    def __init__(
        self,
        hidden_size: int,
        num_heads: int,
    ):
        """
        Args:
            hidden_size: model width; must be divisible by `num_heads`.
            num_heads: number of attention heads.

        Raises:
            ValueError: if `hidden_size` is not divisible by `num_heads`.
        """
        # Start the MRO lookup *after* LlamaAttention so that its own
        # `__init__` is skipped; every sub-module is created below instead.
        super(LlamaAttention, self).__init__()
        self.hidden_size = hidden_size
        self.num_heads = num_heads
        self.head_dim = hidden_size // num_heads

        if (self.head_dim * num_heads) != self.hidden_size:
            raise ValueError(
                "hidden_size must be divisible by num_heads (got "
                f"`hidden_size`: {self.hidden_size}"
                f" and `num_heads`: {num_heads})."
            )
        self.q_proj = nn.Linear(
            hidden_size,
            num_heads * self.head_dim,
            bias=False,
        )
        self.k_proj = nn.Linear(
            hidden_size,
            num_heads * self.head_dim,
            bias=False,
        )
        self.v_proj = nn.Linear(
            hidden_size,
            num_heads * self.head_dim,
            bias=False,
        )
        self.o_proj = nn.Linear(
            num_heads * self.head_dim,
            hidden_size,
            bias=False,
        )
        self.rotary_emb = GistLlamaRotaryEmbedding(self.head_dim)

    def forward(
        self,
        hidden_states: torch.Tensor,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: bool = False,
        use_cache: bool = False,
    ) -> Tuple[torch.Tensor, Optional[torch.Tensor], Optional[Tuple[torch.Tensor]]]:
        """Input shape: Batch x Time x Channel

        Args:
            hidden_states: `(bsz, q_len, hidden_size)` layer input.
            past_key_value: optional `(key, value)` cache whose
                second-to-last dimension is the number of cached positions.
            attention_mask: optional additive mask of shape
                `(bsz, 1, q_len, kv_seq_len)`.
            output_attentions: return the softmaxed attention weights.
            use_cache: return the updated `(key, value)` cache.

        Returns:
            `(attn_output, attn_weights or None, past_key_value or None)`.

        Raises:
            ValueError: if the attention weights, the mask or the attention
                output do not have the expected shape.
        """

        bsz, q_len, _ = hidden_states.size()

        # Project and split into heads: (bsz, num_heads, q_len, head_dim).
        query_states = (
            self.q_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        key_states = (
            self.k_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )
        value_states = (
            self.v_proj(hidden_states)
            .view(bsz, q_len, self.num_heads, self.head_dim)
            .transpose(1, 2)
        )

        kv_seq_len = key_states.shape[-2]
        offset = 0
        if past_key_value is not None:
            # New tokens are positioned right after the cached ones.
            offset = past_key_value[0].shape[-2]
            kv_seq_len += offset
        cos, sin = self.rotary_emb(
            value_states, seq_len=kv_seq_len
        )
        query_states, key_states = apply_rotary_pos_emb(
            query_states, key_states, cos, sin, offset=offset
        )

        if past_key_value is not None:
            # reuse k, v, self_attention
            key_states = torch.cat([past_key_value[0], key_states], dim=2)
            value_states = torch.cat([past_key_value[1], value_states], dim=2)

        past_key_value = (key_states, value_states) if use_cache else None

        attn_weights = torch.matmul(
            query_states, key_states.transpose(2, 3)
        ) / math.sqrt(self.head_dim)

        if attn_weights.size() != (bsz, self.num_heads, q_len, kv_seq_len):
            raise ValueError(
                "Attention weights should be of size "
                f"{(bsz, self.num_heads, q_len, kv_seq_len)}, but is"
                f" {attn_weights.size()}"
            )

        if attention_mask is not None:
            if attention_mask.size() != (bsz, 1, q_len, kv_seq_len):
                raise ValueError(
                    f"Attention mask should be of size "
                    f"{(bsz, 1, q_len, kv_seq_len)}, but is {attention_mask.size()}"
                )
            attn_weights = attn_weights + attention_mask
            # Several additive masks may have been summed upstream; clamp so
            # the scores never fall below the dtype's finite minimum.
            attn_weights = torch.max(
                attn_weights, torch.tensor(torch.finfo(attn_weights.dtype).min)
            )

        # upcast attention to fp32
        attn_weights = nn.functional.softmax(
            attn_weights, dim=-1, dtype=torch.float32
        ).to(query_states.dtype)
        attn_output = torch.matmul(attn_weights, value_states)

        if attn_output.size() != (bsz, self.num_heads, q_len, self.head_dim):
            raise ValueError(
                "`attn_output` should be of size "
                f"{(bsz, self.num_heads, q_len, self.head_dim)}, but is"
                f" {attn_output.size()}"
            )

        # Merge the heads back: (bsz, q_len, hidden_size).
        attn_output = attn_output.transpose(1, 2)
        attn_output = attn_output.reshape(bsz, q_len, self.hidden_size)

        attn_output = self.o_proj(attn_output)

        if not output_attentions:
            attn_weights = None

        return attn_output, attn_weights, past_key_value


class GistLlamaDecoderLayer(LlamaDecoderLayer):
    """Pre-norm decoder block: self-attention, then MLP, each with a residual."""

    def __init__(self, config: LlamaConfig):
        # Begin the MRO lookup after LlamaDecoderLayer so its own `__init__`
        # is bypassed; the sub-modules are assembled here instead.
        super(LlamaDecoderLayer, self).__init__()
        width = config.hidden_size
        norm_eps = config.rms_norm_eps

        self.hidden_size = width
        self.self_attn = GistLlamaAttention(
            hidden_size=width,
            num_heads=config.num_attention_heads,
        )
        self.mlp = LlamaMLP(
            hidden_size=width,
            intermediate_size=config.intermediate_size,
            hidden_act=config.hidden_act,
        )
        self.input_layernorm = LlamaRMSNorm(width, eps=norm_eps)
        self.post_attention_layernorm = LlamaRMSNorm(width, eps=norm_eps)

    def forward(
        self,
        hidden_states: torch.Tensor,
        attention_mask: Optional[torch.Tensor] = None,
        output_attentions: Optional[bool] = False,
        use_cache: Optional[bool] = False,
        past_key_value: Optional[Tuple[torch.Tensor]] = None,
    ) -> Tuple[
        torch.FloatTensor, Optional[Tuple[torch.FloatTensor, torch.FloatTensor]]
    ]:
        """
        Run one decoder block.

        Args:
            hidden_states (`torch.FloatTensor`): block input of shape
                `(batch, seq_len, embed_dim)`.
            attention_mask (`torch.FloatTensor`, *optional*): additive mask of
                size `(batch, 1, tgt_len, src_len)`; masked positions hold
                very large negative values.
            output_attentions (`bool`, *optional*): also return the attention
                weights of this block.
            use_cache (`bool`, *optional*): also return this block's updated
                key/value cache.
            past_key_value (`Tuple(torch.FloatTensor)`, *optional*): cached
                key and value projection states.

        Returns:
            A tuple starting with the block output, followed by the attention
            weights (if `output_attentions`) and then the key/value cache (if
            `use_cache`).
        """
        # Attention sub-block (normalize -> attend -> residual add).
        attn_out, attn_probs, kv_cache = self.self_attn(
            hidden_states=self.input_layernorm(hidden_states),
            past_key_value=past_key_value,
            attention_mask=attention_mask,
            output_attentions=output_attentions,
            use_cache=use_cache,
        )
        after_attention = hidden_states + attn_out

        # Feed-forward sub-block (normalize -> MLP -> residual add).
        mlp_out = self.mlp(self.post_attention_layernorm(after_attention))
        block_output = after_attention + mlp_out

        extras = []
        if output_attentions:
            extras.append(attn_probs)
        if use_cache:
            extras.append(kv_cache)

        return (block_output, *extras)

###### Token embedding ######
class AugmentedTokenEncoder(torch.nn.Module):
    """Embed a mix of pretrained tokens, newly added tokens and scalar labels.

    * ids below `PRETRAINED_VOCAB_SIZE` go through `original_embedder`;
    * ids in `[PRETRAINED_VOCAB_SIZE, PRETRAINED_VOCAB_SIZE + num_new_tokens)`
      go through a separate `embedding` table (indexed from 0);
    * when `icl_labels` is given, positions whose id is 0 are overwritten with
      a linear encoding (`encode_reg`) of those labels.
    """

    def __init__(self, config, embedder, num_new_tokens):
        """
        Args:
            config: object providing `hidden_size` and `regression_out_dim`.
            embedder: module used to embed the pretrained (text) token ids.
            num_new_tokens: number of rows in the added-token table.
        """
        super().__init__()

        self.token_dim = config.hidden_size
        self.original_embedder = embedder

        self.embedding = nn.Embedding(num_new_tokens, self.token_dim)

        self.vocab_size = PRETRAINED_VOCAB_SIZE
        # Ids of the added tokens: one contiguous range right after the
        # pretrained vocabulary.
        self.added_tokens = [self.vocab_size + i for i in range(num_new_tokens)]
        self.encode_reg = nn.Linear(config.regression_out_dim, self.token_dim, bias=False)

    def forward(self, input_ids: torch.Tensor, icl_labels: Optional[torch.Tensor] = None):
        """Return embeddings of shape `(*input_ids.shape, token_dim)`.

        Args:
            input_ids: integer tensor of token ids.
            icl_labels: optional tensor of scalar labels, one per position
                whose id is 0 (it is indexed into those positions after
                `unsqueeze(-1)` and `encode_reg`).
                NOTE(review): the `unsqueeze(-1)` implies `encode_reg` takes a
                single input feature, i.e. `regression_out_dim == 1` -- confirm.
        """
        embeddings = torch.zeros(
            (*input_ids.shape, self.token_dim),
            dtype=self.embedding.weight.data.dtype,
            device=input_ids.device,
        )

        # Special tokens: `added_tokens` is a contiguous id range, so a single
        # range test replaces one comparison per added token.
        first_added = self.vocab_size
        end_added = first_added + len(self.added_tokens)
        token_indices = (input_ids >= first_added) & (input_ids < end_added)
        text_indices = ~token_indices

        if token_indices.any():
            token_ids = input_ids[token_indices] - self.vocab_size
            embeddings[token_indices] = self.embedding(token_ids)

        # Text (this includes id-0 placeholders, which may be overwritten below)
        embeddings[text_indices] = self.original_embedder(input_ids[text_indices])

        # Scalar encoding
        if icl_labels is not None:
            icl_indices = input_ids == 0
            if icl_indices.any():
                icl_labels = icl_labels.to(embeddings.dtype)
                encoded = self.encode_reg(icl_labels.unsqueeze(-1)).to(embeddings.dtype)
                embeddings[icl_indices] = encoded

        return embeddings


class GistLlamaModel(LlamaPreTrainedModel):
    """
    Transformer decoder consisting of *config.num_hidden_layers* layers. Each
    layer is a [`GistLlamaDecoderLayer`].

    Token ids are embedded by an `AugmentedTokenEncoder` when
    `config.num_new_tokens > 0`, and by the plain `embed_tokens` table
    otherwise.

    Args:
        config: LlamaConfig
    """

    def __init__(self, config: LlamaConfig):
        super().__init__(config)
        self.padding_idx = config.pad_token_id
        self.vocab_size = config.vocab_size

        self.embed_tokens = nn.Embedding(
            config.vocab_size, config.hidden_size, self.padding_idx
        )

        # A config without `num_new_tokens` is treated as "no added tokens".
        if getattr(config, "num_new_tokens", 0) > 0:
            self.augmented_embedder = AugmentedTokenEncoder(config, self.embed_tokens, config.num_new_tokens)

        self.layers = nn.ModuleList(
            [GistLlamaDecoderLayer(config) for _ in range(config.num_hidden_layers)]
        )
        self.norm = LlamaRMSNorm(config.hidden_size, eps=config.rms_norm_eps)

        self.gradient_checkpointing = False

        # Initialize weights and apply final processing
        self.post_init()

    def get_input_embeddings(self):
        return self.embed_tokens

    def set_input_embeddings(self, value):
        # NOTE(review): `augmented_embedder.original_embedder` keeps pointing
        # at the module it was constructed with, so replacing `embed_tokens`
        # here does not change what `forward` uses when the augmented embedder
        # exists -- confirm whether that is intended.
        self.embed_tokens = value

    # Copied from transformers.models.bart.modeling_bart.BartDecoder._prepare_decoder_attention_mask
    def _prepare_decoder_attention_mask(
        self, attention_mask, input_shape, inputs_embeds, past_key_values_length
    ):
        """Combine the causal mask and the padding mask into one additive mask.

        Args:
            attention_mask: `[bsz, src_len]` keep-mask, or `None`.
            input_shape: `(bsz, seq_len)` of the current inputs.
            inputs_embeds: only its dtype and device are used.
            past_key_values_length: number of cached positions.

        Returns:
            The additive mask, or `None` when there is a single query position
            and no `attention_mask`.
        """
        # create causal mask
        # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
        combined_attention_mask = None
        if input_shape[-1] > 1:
            combined_attention_mask = _make_causal_mask(
                input_shape,
                inputs_embeds.dtype,
                past_key_values_length=past_key_values_length,
            ).to(inputs_embeds.device)

        if attention_mask is not None:
            # [bsz, seq_len] -> [bsz, 1, tgt_seq_len, src_seq_len]
            expanded_attn_mask = _expand_mask(
                attention_mask, inputs_embeds.dtype, tgt_len=input_shape[-1]+past_key_values_length
            ).to(inputs_embeds.device)

            # NOTE(review): with a cache the expanded mask is transposed, so
            # its last two dimensions become (src_len, seq_len + past). That
            # only lines up with the causal mask when `attention_mask` covers
            # just the current tokens -- confirm against the callers.
            if past_key_values_length > 0:
                expanded_attn_mask = expanded_attn_mask.transpose(-1, -2)

            combined_attention_mask = (
                expanded_attn_mask
                if combined_attention_mask is None
                else expanded_attn_mask + combined_attention_mask
            )

        return combined_attention_mask

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_gist: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        gist_activations: Optional[GistActivations] = None,
        icl_labels: Optional[List[float]] = None,
        return_dict: Optional[bool] = None,
        prev_hidden:Optional[torch.FloatTensor] = None,
    ) -> Union[Tuple, BaseModelOutputWithPast]:
        r"""
        Args:
            input_ids (`torch.LongTensor` of shape `(batch_size, sequence_length)`):
                Indices of input sequence tokens in the vocabulary. Padding will
                be ignored by default should you provide it.

                Indices can be obtained using [`AutoTokenizer`]. See
                [`PreTrainedTokenizer.encode`] and
                [`PreTrainedTokenizer.__call__`] for details.

                [What are input IDs?](../glossary#input-ids)
            attention_mask (`torch.Tensor` of shape `(batch_size,
                    sequence_length)`, *optional*):
                Mask to avoid performing attention on padding token indices.
                Mask values selected in `[0, 1]`:

                - 1 for tokens that are **not masked**,
                - 0 for tokens that are **masked**.

                When omitted, an all-ones mask covering the cached plus the
                current positions is used.

                [What are attention masks?](../glossary#attention-mask)
            attention_mask_gist (`torch.Tensor`, *optional*):
                Extra mask applied on top of the combined causal/padding mask:
                positions where it is truthy are blocked (the dtype's minimum
                is added). It must be broadcastable to the combined mask.
            past_key_values (`tuple(tuple(torch.FloatTensor))`, *optional*):
                One `(key, value)` pair per layer; the number of cached
                positions is read from dimension 2 of the first key.

                If `past_key_values` are used, the user can optionally input
                only the last `decoder_input_ids` (those that don't have their
                past key value states given to this model) of shape
                `(batch_size, 1)` instead of all `decoder_input_ids` of shape
                `(batch_size, sequence_length)`.
            inputs_embeds (`torch.FloatTensor` of shape `(batch_size,
                    sequence_length, hidden_size)`, *optional*):
                Optionally, instead of passing `input_ids` you can choose to
                directly pass an embedded representation.
            use_cache (`bool`, *optional*):
                If set to `True`, `past_key_values` key value states are
                returned and can be used to speed up decoding (see
                `past_key_values`).
            output_attentions (`bool`, *optional*):
                Whether or not to return the attentions tensors of all attention
                layers. See `attentions` under returned tensors for more detail.
            output_hidden_states (`bool`, *optional*):
                Whether or not to return the hidden states of all layers. See
                `hidden_states` under returned tensors for more detail.
            gist_activations (`GistActivations`, *optional*):
                Accepted for interface compatibility; not used by this method.
            icl_labels (*optional*):
                Scalar labels forwarded to the augmented embedder (ignored
                when the model has no augmented embedder).
            return_dict (`bool`, *optional*):
                Whether or not to return a [`~utils.ModelOutput`] instead of a
                plain tuple.
            prev_hidden (`torch.FloatTensor`, *optional*):
                Accepted for interface compatibility; not used by this method.

        Raises:
            ValueError: if both or neither of `input_ids` / `inputs_embeds`
                are given.
        """
        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        use_cache = use_cache if use_cache is not None else self.config.use_cache

        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        # retrieve input_ids and inputs_embeds
        if input_ids is not None and inputs_embeds is not None:
            raise ValueError(
                "You cannot specify both decoder_input_ids and "
                "decoder_inputs_embeds at the same time"
            )
        elif input_ids is not None:
            batch_size, seq_length = input_ids.shape
        elif inputs_embeds is not None:
            batch_size, seq_length, _ = inputs_embeds.shape
        else:
            raise ValueError(
                "You have to specify either decoder_input_ids or decoder_inputs_embeds"
            )

        seq_length_with_past = seq_length
        past_key_values_length = 0
        if past_key_values is not None:
            past_key_values_length = past_key_values[0][0].shape[2]
            seq_length_with_past = seq_length_with_past + past_key_values_length
        if inputs_embeds is None:
            # The augmented embedder only exists when the config asks for new
            # tokens; fall back to the plain table instead of raising
            # AttributeError.
            augmented_embedder = getattr(self, "augmented_embedder", None)
            if augmented_embedder is not None:
                inputs_embeds = augmented_embedder(input_ids, icl_labels)
            else:
                inputs_embeds = self.embed_tokens(input_ids)

        # embed positions
        if attention_mask is None:
            attention_mask = torch.ones(
                (batch_size, seq_length_with_past),
                dtype=torch.bool,
                device=inputs_embeds.device,
            )
        attention_mask = self._prepare_decoder_attention_mask(
            attention_mask,
            (batch_size, seq_length),
            inputs_embeds,
            past_key_values_length,
        )

        if attention_mask_gist is not None:
            # Convert the gist mask to an additive mask and stack it on top of
            # the causal/padding mask.
            attention_mask_gist_float = torch.full_like(
                attention_mask, 0.0
            )
            attention_mask_gist_float = attention_mask_gist_float.masked_fill(
                attention_mask_gist.bool(), torch.finfo(attention_mask.dtype).min
            )
            attention_mask = attention_mask + attention_mask_gist_float

        hidden_states = inputs_embeds

        if self.gradient_checkpointing and self.training:
            if use_cache:
                logger.warning_once(
                    "`use_cache=True` is incompatible with gradient checkpointing. "
                    "Setting `use_cache=False`..."
                )
                use_cache = False

        # decoder layers
        all_hidden_states = () if output_hidden_states else None
        all_self_attns = () if output_attentions else None
        next_decoder_cache = () if use_cache else None

        for idx, decoder_layer in enumerate(self.layers):
            if output_hidden_states:
                all_hidden_states += (hidden_states,)

            past_key_value = (
                past_key_values[idx] if past_key_values is not None else None
            )

            if self.gradient_checkpointing and self.training:

                def create_custom_forward(module):
                    def custom_forward(*inputs):
                        # Pass the flags by keyword: the layer's positional
                        # order is (hidden_states, attention_mask,
                        # output_attentions, use_cache, past_key_value), so
                        # positional extras would land in the wrong slots.
                        return module(
                            *inputs,
                            output_attentions=output_attentions,
                            use_cache=False,
                            past_key_value=None,
                        )

                    return custom_forward

                layer_outputs = torch.utils.checkpoint.checkpoint(
                    create_custom_forward(decoder_layer),
                    hidden_states,
                    attention_mask,
                )
            else:
                layer_outputs = decoder_layer(
                    hidden_states,
                    attention_mask=attention_mask,
                    past_key_value=past_key_value,
                    output_attentions=output_attentions,
                    use_cache=use_cache,
                )

            hidden_states = layer_outputs[0]

            if use_cache:
                next_decoder_cache += (layer_outputs[2 if output_attentions else 1],)

            if output_attentions:
                all_self_attns += (layer_outputs[1],)

        hidden_states = self.norm(hidden_states)

        # add hidden states from the last decoder layer
        if output_hidden_states:
            all_hidden_states += (hidden_states,)

        next_cache = next_decoder_cache if use_cache else None
        if not return_dict:
            return tuple(
                v
                for v in [hidden_states, next_cache, all_hidden_states, all_self_attns]
                if v is not None
            )
        return BaseModelOutputWithPast(
            last_hidden_state=hidden_states,
            past_key_values=next_cache,
            hidden_states=all_hidden_states,
            attentions=all_self_attns,
        )


class GistLlamaForCausalLM(LlamaPreTrainedModel, GistGenerationMixin):
    """`GistLlamaModel` with a language-modeling head and an optional
    regression head.

    `lm_head_reg` starts as `None`; this class never creates it, so it has to
    be assigned from outside before any regression path is used. Besides the
    forward computation, `forward` also keeps running averages of the two
    losses and periodically writes artifacts into `config.output_dir`.
    """

    _keys_to_ignore_on_load_missing = [r"lm_head.weight"]

    def __init__(self, config):
        super().__init__(config)
        self.model = GistLlamaModel(config)

        self.lm_head = nn.Linear(config.hidden_size, config.vocab_size, bias=False)

        self.lm_head_reg = None
        self.word_embeddings = None

        # Initialize weights and apply final processing
        self.post_init()

        # Bookkeeping updated by `forward`:
        self.step_for_save = 0  # number of return_dict forwards seen so far
        self.regl = 0  # running mean of the (scaled) regression loss
        self.regstep = 0  # number of regression-loss updates
        self.clfl = 0  # running mean of the cross-entropy loss
        self.clfstep = 0  # number of cross-entropy updates
        self.output_dir = config.output_dir

        # When True, `forward` accumulates predictions/labels and dumps them.
        self.manual_eval = False
        self.pred_val = []
        self.true_val = []
        self.embeds = []

    @torch.no_grad()
    def get_gist_activations(
        self,
        input_ids: torch.LongTensor,
        attention_mask: torch.FloatTensor,
        attention_mask_gist: torch.FloatTensor,
        gist_token: int,
        num_gist_tokens: int,
        cache_all: bool = False,
    ) -> GistActivations:
        """Run the decoder once (hidden states and cache enabled, no
        gradients) and build `GistActivations` from its outputs."""
        model_outputs = self.model(
            input_ids,
            attention_mask=attention_mask,
            attention_mask_gist=attention_mask_gist,
            output_hidden_states=True,
            use_cache=True,
        )
        return GistActivations.from_model_outputs(
            model_outputs=model_outputs,
            input_ids=input_ids,
            gist_token=gist_token,
            num_gist_tokens=num_gist_tokens,
            cache_all=cache_all,
        )

    def get_input_embeddings(self):
        return self.model.embed_tokens

    def set_input_embeddings(self, value):
        self.model.embed_tokens = value

    def get_output_embeddings(self):
        return self.lm_head

    def set_output_embeddings(self, new_embeddings):
        self.lm_head = new_embeddings

    def set_decoder(self, decoder):
        self.model = decoder

    def get_decoder(self):
        return self.model

    def _save_training_artifacts(self):
        """Best-effort dump of the trainable extras into `self.output_dir`.

        Each group is saved independently; a failure is logged and does not
        interrupt the forward pass.
        """
        try:
            torch.save(self.model.augmented_embedder.state_dict(), self.output_dir + "/augmented_embedder.pth")
            np.save(self.output_dir + "/embedder_weights.npy", self.model.augmented_embedder.embedding.weight.data.clone().detach().cpu().numpy())
        except Exception as err:
            logger.warning("Could not save the augmented embedder: %s", err)
        try:
            torch.save(self.lm_head_reg.state_dict(), self.output_dir + "/lm_head_reg.pth")
            np.save(self.output_dir + "/lm_head_reg.npy", self.lm_head_reg.weight.data.detach().cpu().numpy())
        except Exception as err:
            logger.warning("Could not save the regression head: %s", err)
        try:
            np.save(self.output_dir + "/encode_reg.npy", self.model.augmented_embedder.encode_reg.weight.data.detach().cpu().numpy())
        except Exception as err:
            logger.warning("Could not save the label encoder weights: %s", err)

    def forward(
        self,
        input_ids: torch.LongTensor = None,
        attention_mask: Optional[torch.Tensor] = None,
        attention_mask_gist: Optional[torch.Tensor] = None,
        past_key_values: Optional[List[torch.FloatTensor]] = None,
        inputs_embeds: Optional[torch.FloatTensor] = None,
        labels: Optional[torch.LongTensor] = None,
        label_reg: Optional[torch.FloatTensor] = None,
        reg_idx: Optional[List[int]] = None,
        reg_dim: Optional[List[int]] = None,
        clf_idx: Optional[List[int]] = None,
        reg_pred_idx: Optional[List[int]] = None,
        icl_labels: Optional[List[float]] = None,
        use_cache: Optional[bool] = None,
        output_attentions: Optional[bool] = None,
        output_hidden_states: Optional[bool] = None,
        return_dict: Optional[bool] = None,
        gist_activations: Optional[GistActivations] = None,
        manual_eval: bool = False,
        **kwargs,
    ) -> Union[Tuple, CausalLMOutputWithPast]:
        """Run the decoder and compute the regression / language-model losses.

        With `labels`:
            * if `reg_idx` is given, `lm_head_reg` is applied to
              `hidden_states[reg_idx, reg_pred_idx, :]`, column `reg_dim[i]`
              of row `i` is selected, and `100 * MSE` against `label_reg` is
              added to the loss;
            * if `clf_idx` is given, `lm_head` is applied to
              `hidden_states[clf_idx]` and the shifted next-token
              cross-entropy against `labels[clf_idx]` is added (`logits` is
              only set in this case);
            * a penalty pulling the L2 norm of each added-token embedding
              towards `AVG_NORM` is added when the decoder has an augmented
              embedder.
        Without `labels`: `logits` comes from `lm_head_reg` on the last
        position when `label_reg` is given, otherwise from `lm_head` on every
        position.

        Side effects (only on the `return_dict` path): progress printing,
        periodic `np.save` / `torch.save` calls into `self.output_dir`, and,
        when `self.manual_eval` and `output_attentions` are both set, dumping
        the attentions and terminating the process.

        NOTE(review): the `manual_eval` parameter is not read; the behavior
        is driven by `self.manual_eval`. `**kwargs` is ignored.
        """

        output_attentions = (
            output_attentions
            if output_attentions is not None
            else self.config.output_attentions
        )
        output_hidden_states = (
            output_hidden_states
            if output_hidden_states is not None
            else self.config.output_hidden_states
        )
        return_dict = (
            return_dict if return_dict is not None else self.config.use_return_dict
        )

        outputs = self.model(
            input_ids=input_ids,
            attention_mask=attention_mask,
            attention_mask_gist=attention_mask_gist,
            past_key_values=past_key_values,
            inputs_embeds=inputs_embeds,
            use_cache=use_cache,
            output_attentions=output_attentions,
            output_hidden_states=output_hidden_states,
            return_dict=return_dict,
            gist_activations=gist_activations,
            icl_labels=icl_labels
        )

        hidden_states = outputs[0]

        logits = None
        loss = None

        if labels is not None:
            loss = 0
            if reg_idx is not None:

                logits_reg = self.lm_head_reg(hidden_states[reg_idx, reg_pred_idx, :])
                # One output column per example, chosen by `reg_dim`.
                reg_val = logits_reg[torch.arange(logits_reg.shape[0]), reg_dim]

                loss_fct = MSELoss()
                loss = loss + loss_fct(reg_val, label_reg) * 100

                self.regl = (self.regl * self.regstep + loss.item()) / (self.regstep + 1)
                self.regstep += 1

                if self.manual_eval:
                    if len(self.pred_val) == 0:
                        print("saving predictions")
                    reg_detach = reg_val.detach().cpu().numpy()
                    lab_detach = label_reg.detach().cpu().numpy()
                    for i in range(label_reg.shape[0]):
                        self.pred_val.append(reg_detach[i])
                        self.true_val.append(lab_detach[i])

            if clf_idx is not None:
                logits = self.lm_head(hidden_states[clf_idx, ...])
                labels = labels[clf_idx, ...]
                # Position t predicts token t + 1.
                shift_logits = logits[..., :-1, :].contiguous()
                shift_labels = labels[..., 1:].contiguous()

                if shift_logits.shape[:2] != shift_labels.shape:
                    shift_labels = shift_labels[..., :shift_logits.shape[1]].contiguous()

                loss_fct = CrossEntropyLoss()
                loss_clf = loss_fct(
                    shift_logits.view(-1, self.config.vocab_size), shift_labels.view(-1)
                )
                loss = loss + loss_clf
                self.clfl = (self.clfl * self.clfstep + loss_clf.item()) / (self.clfstep + 1)
                self.clfstep += 1

            # norm regularization
            # Use the parameter itself rather than `.data`: `.data` is detached
            # from autograd, which would make this term a constant with no
            # gradient. Skipped when the decoder has no augmented embedder.
            augmented_embedder = getattr(self.model, "augmented_embedder", None)
            if augmented_embedder is not None:
                added_token_norms = torch.linalg.norm(
                    augmented_embedder.embedding.weight, dim=1
                )
                loss = loss + ((added_token_norms - AVG_NORM) ** 2).mean()

        else:
            if label_reg is not None:
                logits = self.lm_head_reg(hidden_states[...,-1, :])
            else:
                logits = self.lm_head(hidden_states)

        if not return_dict:
            output = (logits,) + outputs[1:]
            return (loss,) + output if loss is not None else output

        if self.manual_eval and output_attentions:
            # Debug path: dump inputs and attentions, then stop the process.
            attn = [layer_attn.detach().cpu().numpy() for layer_attn in outputs.attentions]
            np.save(self.output_dir + "/input.npy", input_ids.detach().cpu().numpy())
            np.save(self.output_dir + "/attn.npy", attn)
            raise SystemExit(0)

        if self.step_for_save % 4000 == 0:
            print("Step", self.step_for_save, "\tAccumulated MSE:", self.regl, "\tAccumulated CE:", self.clfl)

        self.step_for_save += 1
        if self.step_for_save % 10 == 0 and self.manual_eval:
            print("Save")
            np.save(self.output_dir + "/preds.npy", self.pred_val)
            np.save(self.output_dir + "/labels.npy", self.true_val)

        if not self.manual_eval and self.step_for_save % 4000 == 0:
            self._save_training_artifacts()

        return CausalLMOutputWithPast(
            loss=loss,
            logits=logits,
            past_key_values=outputs.past_key_values,
            hidden_states=outputs.hidden_states,
            attentions=outputs.attentions,
        )

    def prepare_inputs_for_generation(
        self,
        input_ids,
        past_key_values=None,
        attention_mask=None,
        attention_mask_gist=None,
        inputs_embeds=None,
        gist_offset=None,
        first_time=False,
        **kwargs,
    ):
        """Assemble the keyword arguments for one generation step.

        Returns a dict with either `inputs_embeds` (first step only) or
        `input_ids`, plus `past_key_values`, `use_cache`, `attention_mask`
        and `attention_mask_gist`.
        """
        # Don't trim the input ids if we are caching gists (indicated by
        # non-None gist_offset).
        first_time_with_gist_caching = first_time and gist_offset is not None
        if past_key_values and not first_time_with_gist_caching:
            input_ids = input_ids[:, -1:]

        # if `inputs_embeds` are passed, we only want to use them in the 1st
        # generation step
        if inputs_embeds is not None and past_key_values is None:
            model_inputs = {"inputs_embeds": inputs_embeds}
        else:
            model_inputs = {"input_ids": input_ids}

        model_inputs.update(
            {
                "past_key_values": past_key_values,
                "use_cache": kwargs.get("use_cache"),
                "attention_mask": attention_mask,
                "attention_mask_gist": attention_mask_gist,
            }
        )

        return model_inputs

    @staticmethod
    def _reorder_cache(past_key_values, beam_idx):
        """Re-index every cached tensor along dim 0 with `beam_idx`."""
        reordered_past = ()
        for layer_past in past_key_values:
            reordered_past += (
                tuple(
                    past_state.index_select(0, beam_idx) for past_state in layer_past
                ),
            )
        return reordered_past
